{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {},
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/NeuromatchAcademy/course-content/blob/master/tutorials/W3D2_HiddenDynamics/W3D2_Tutorial1.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a> &nbsp; <a href=\"https://kaggle.com/kernels/welcome?src=https://raw.githubusercontent.com/NeoNeuron/professional-workshop-3/master/tutorials/W8_HiddenDynamics/W8_Tutorial1.ipynb\" target=\"_parent\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" alt=\"Open in Kaggle\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "# Tutorial 1: Sequential Probability Ratio Test\n",
    "\n",
    "__Content creators:__ Yicheng Fei and Xaq Pitkow\n",
    "\n",
    "__Content modified:__ Kai Chen\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "# Tutorial Objectives\n",
    "\n",
    "On Bayes Day, we learned how to combine the sensory measurement $m$ about a latent variable $s$ with and our prior knowledge, using Bayes' Theorem. This produced a posterior probability distribution $p(s|m)$. Today we will allow for _dynamic_ world states and measurements.\n",
    "\n",
    "In Tutorial 1 we will assume that the world state is _binary_ ($\\pm 1$) and _constant_ over time, but allow for multiple observations over time. We will use the *Sequential Probability Ratio Test* (SPRT) to infer which state is true. This leads to the *Drift Diffusion Model (DDM)* where evidence accumulates until reaching a stopping criterion.\n",
    "\n",
    "By the end of this tutorial, you should be able to:\n",
    "- Define and implement the Sequential Probability Ratio Test for a series of measurements\n",
    "- Define what drift and diffusion mean in a drift-diffusion model\n",
    "- Explain the speed-accuracy trade-off in a drift diffusion model\n",
    "\n",
    "**Summary of Exercises**\n",
    "\n",
    "0. Bonus (math): derive the Drift Diffusion Model mathematically from SPRT\n",
    "\n",
    "1. Simulate the DDM\n",
    "    1. _Code_: Accumulate evidence and make a decision (DDM)\n",
    "    2. _Interactive_: Manipulate parameters and interpret\n",
    "\n",
    "2. Analyze the DDM\n",
    "    1. _Code_: Quantify speed-accuracy tradeoff\n",
    "    2. _Interactive_: Manipulate parameters and interpret"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "---\n",
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# Imports\n",
    "\n",
    "import numpy as np\n",
    "from scipy import stats\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.special import erf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# @title Figure Settings\n",
    "import ipywidgets as widgets       # interactive display\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "\n",
    "nma_style = {\n",
    "    'figure.figsize' : (8, 6),\n",
    "    'figure.autolayout' : True,\n",
    "    'font.size' : 15,\n",
    "    'xtick.labelsize' : 'small',\n",
    "    'ytick.labelsize' : 'small',\n",
    "    'legend.fontsize' : 'small',\n",
    "    'axes.spines.top' : False,\n",
    "    'axes.spines.right' : False,\n",
    "    'xtick.major.size' : 5,\n",
    "    'ytick.major.size' : 5,\n",
    "}\n",
    "for key, value in nma_style.items():\n",
    "    plt.rcParams[key] = value\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# @title Helper Functions\n",
    "\n",
    "def simulate_and_plot_SPRT_fixedtime(mu, sigma, stop_time, num_sample,\n",
    "                                     verbose=True):\n",
    "  \"\"\"Simulate and plot a SPRT for a fixed amount of time given a std.\n",
    "\n",
    "  Args:\n",
    "    mu (float): absolute mean value of the symmetric observation distributions\n",
    "    sigma (float): Standard deviation of the observations.\n",
    "    stop_time (int): Number of steps to run before stopping.\n",
    "    num_sample (int): The number of samples to plot.\n",
    "    \"\"\"\n",
    "\n",
    "  evidence_history_list = []\n",
    "  if verbose:\n",
    "    print(\"#Trial\\tTotal_Evidence\\tDecision\")\n",
    "  for i in range(num_sample):\n",
    "    evidence_history, decision, Mvec = simulate_SPRT_fixedtime(mu, sigma, stop_time)\n",
    "    if verbose:\n",
    "      print(\"{}\\t{:f}\\t{}\".format(i, evidence_history[-1], decision))\n",
    "    evidence_history_list.append(evidence_history)\n",
    "\n",
    "  fig, ax = plt.subplots()\n",
    "  maxlen_evidence = np.max(list(map(len,evidence_history_list)))\n",
    "  ax.plot(np.zeros(maxlen_evidence), '--', c='red', alpha=1.0)\n",
    "  for evidences in evidence_history_list:\n",
    "    ax.plot(np.arange(len(evidences)), evidences)\n",
    "    ax.set_xlabel(\"Time\")\n",
    "    ax.set_ylabel(\"Accumulated log likelihood ratio\")\n",
    "    ax.set_title(\"Log likelihood ratio trajectories under the fixed-time \" +\n",
    "                  \"stopping rule\")\n",
    "\n",
    "  plt.show(fig)\n",
    "\n",
    "\n",
    "def plot_accuracy_vs_stoptime(mu, sigma, stop_time_list, accuracy_analytical_list, accuracy_list=None):\n",
    "  \"\"\"Simulate and plot a SPRT for a fixed amount of times given a std.\n",
    "\n",
    "  Args:\n",
    "    mu (float): absolute mean value of the symmetric observation distributions\n",
    "    sigma (float): Standard deviation of the observations.\n",
    "    stop_time_list (int): List of number of steps to run before stopping.\n",
    "    accuracy_analytical_list (int): List of analytical accuracies for each stop time\n",
    "    accuracy_list (int (optional)): List of simulated accuracies for each stop time\n",
    "  \"\"\"\n",
    "  T = stop_time_list[-1]\n",
    "  fig, ax = plt.subplots(figsize=(12,8))\n",
    "  ax.set_xlabel('Stop Time')\n",
    "  ax.set_ylabel('Average Accuracy')\n",
    "  ax.plot(stop_time_list, accuracy_analytical_list)\n",
    "  if accuracy_list is not None:\n",
    "    ax.plot(stop_time_list, accuracy_list)\n",
    "  ax.legend(['analytical','simulated'], loc='upper center')\n",
    "\n",
    "  # Show two gaussian\n",
    "  stop_time_list_plot = [max(1,T//10), T*2//3]\n",
    "  sigma_st_max = 2*mu*np.sqrt(stop_time_list_plot[-1])/sigma\n",
    "  domain = np.linspace(-3*sigma_st_max,3*sigma_st_max,50)\n",
    "  for stop_time in stop_time_list_plot:\n",
    "    ins = ax.inset_axes([stop_time/T,0.05,0.2,0.3])\n",
    "    for pos in ['right', 'top', 'bottom', 'left']:\n",
    "      ins.spines[pos].set_visible(False)\n",
    "    ins.axis('off')\n",
    "    ins.set_title(f\"stop_time={stop_time}\")\n",
    "\n",
    "    left = np.zeros_like(domain)\n",
    "    mu_st = 4*mu*mu*stop_time/2/sigma**2\n",
    "    sigma_st = 2*mu*np.sqrt(stop_time)/sigma\n",
    "    for i, mu1 in enumerate([-mu_st,mu_st]):\n",
    "      rv = stats.norm(mu1, sigma_st)\n",
    "      offset = rv.pdf(domain)\n",
    "      # lbl = \"measurement distribution\" if i==0 else \"\"\n",
    "      lbl = \"summed evidence\" if i==1 else \"\"\n",
    "      color = \"crimson\"\n",
    "      ls = \"solid\" if i==1 else \"dashed\"\n",
    "      ins.plot(domain, left+offset, label=lbl, color=color,ls=ls)\n",
    "\n",
    "    rv = stats.norm(mu_st, sigma_st)\n",
    "    domain0 = np.linspace(-3*sigma_st_max,0,50)\n",
    "    offset = rv.pdf(domain0)\n",
    "    ins.fill_between(domain0, np.zeros_like(domain0), offset, color=\"crimson\", label=\"error\")\n",
    "    ins.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left')\n",
    "\n",
    "\n",
    "    # ins.legend(loc=\"upper right\")\n",
    "\n",
    "  plt.show(fig)\n",
    "\n",
    "\n",
    "def simulate_and_plot_SPRT_fixedthreshold(mu, sigma, num_sample, alpha,\n",
    "                                          verbose=True):\n",
    "  \"\"\"Simulate and plot a SPRT for a fixed amount of times given a std.\n",
    "\n",
    "  Args:\n",
    "    mu (float): absolute mean value of the symmetric observation distributions\n",
    "    sigma (float): Standard deviation of the observations.\n",
    "    num_sample (int): The number of samples to plot.\n",
    "    alpha (float): Threshold for making a decision.\n",
    "  \"\"\"\n",
    "  # calculate evidence threshold from error rate\n",
    "  threshold = threshold_from_errorrate(alpha)\n",
    "\n",
    "  # run simulation\n",
    "  evidence_history_list = []\n",
    "  if verbose:\n",
    "    print(\"#Trial\\tTime\\tAccumulated Evidence\\tDecision\")\n",
    "  for i in range(num_sample):\n",
    "    evidence_history, decision, Mvec = simulate_SPRT_threshold(mu, sigma, threshold)\n",
    "    if verbose:\n",
    "      print(\"{}\\t{}\\t{:f}\\t{}\".format(i, len(Mvec), evidence_history[-1],\n",
    "                                      decision))\n",
    "    evidence_history_list.append(evidence_history)\n",
    "\n",
    "  fig, ax = plt.subplots()\n",
    "  maxlen_evidence = np.max(list(map(len,evidence_history_list)))\n",
    "  ax.plot(np.repeat(threshold,maxlen_evidence + 1), c=\"red\")\n",
    "  ax.plot(-np.repeat(threshold,maxlen_evidence + 1), c=\"red\")\n",
    "  ax.plot(np.zeros(maxlen_evidence + 1), '--', c='red', alpha=0.5)\n",
    "\n",
    "  for evidences in evidence_history_list:\n",
    "      ax.plot(np.arange(len(evidences) + 1), np.concatenate([[0], evidences]))\n",
    "\n",
    "  ax.set_xlabel(\"Time\")\n",
    "  ax.set_ylabel(\"Accumulated log likelihood ratio\")\n",
    "  ax.set_title(\"Log likelihood ratio trajectories under the threshold rule\")\n",
    "\n",
    "  plt.show(fig)\n",
    "\n",
    "\n",
    "def simulate_and_plot_accuracy_vs_threshold(mu, sigma, threshold_list, num_sample):\n",
    "  \"\"\"Simulate and plot a SPRT for a set of thresholds given a std.\n",
    "\n",
    "  Args:\n",
    "    mu (float): absolute mean value of the symmetric observation distributions\n",
    "    sigma (float): Standard deviation of the observations.\n",
    "    alpha_list (float): List of thresholds for making a decision.\n",
    "    num_sample (int): The number of samples to plot.\n",
    "  \"\"\"\n",
    "  accuracies, decision_speeds = simulate_accuracy_vs_threshold(mu, sigma,\n",
    "                                                               threshold_list,\n",
    "                                                               num_sample)\n",
    "\n",
    "  # Plotting\n",
    "  fig, ax = plt.subplots()\n",
    "  ax.plot(decision_speeds, accuracies, linestyle=\"--\", marker=\"o\")\n",
    "  ax.plot([np.amin(decision_speeds), np.amax(decision_speeds)],\n",
    "          [0.5, 0.5], c='red')\n",
    "  ax.set_xlabel(\"Average Decision speed\")\n",
    "  ax.set_ylabel('Average Accuracy')\n",
    "  ax.set_title(\"Speed/Accuracy Tradeoff\")\n",
    "  ax.set_ylim(0.45, 1.05)\n",
    "\n",
    "  plt.show(fig)\n",
    "\n",
    "\n",
    "def threshold_from_errorrate(alpha):\n",
    "  \"\"\"Calculate log likelihood ratio threshold from desired error rate `alpha`\n",
    "\n",
    "  Args:\n",
    "    alpha (float): in (0,1), the desired error rate\n",
    "\n",
    "  Return:\n",
    "    threshold: corresponding evidence threshold\n",
    "  \"\"\"\n",
    "  threshold = np.log((1. - alpha) / alpha)\n",
    "  return threshold"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "---\n",
    "\n",
    "# Section 1: Sequential Probability Ratio Test as a Drift Diffusion Model\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "**Sequential Probability Ratio Test**\n",
    "\n",
    "The Sequential Probability Ratio Test is a likelihood ratio test for determining which of two hypotheses is more likely. It is appropriate for sequential independent and identially distributed (iid) data. iid means that the data comes from the same distribution.\n",
    "\n",
    "Let's return to what we learned yesterday. We had probabilities of our measurement ($m$) given a state of the world ($s$). For example, we knew the probability of seeing someone catch a fish while fishing on the left side given that the fish were on the left side $P(m = \\textrm{catch fish} | s = \\textrm{left})$.\n",
    "\n",
    "Now let's extend this slightly to assume we take a series of measurements, from time 1 up to time t ($m_{1:t}$), and that our state is either +1 or -1. We want to figure out what the state is, given our measurements. To do this, we can compare the total evidence up to time $t$ for our two hypotheses (that the state is +1 or that the state is -1). We do this by computing a likelihood ratio: the ratio of the likelihood of all these measurements given the state is +1, $p(m_{1:t}|s=+1)$, to the likelihood of the measurements given the state is -1, $p(m_{1:t}|s=-1)$. This is our likelihood ratio test. In fact, we want to take the log of this likelihood ratio to give us the log likelihood ratio $L_T$.\n",
    "\n",
    "\\begin{align*}\n",
    "L_T &= log\\frac{p(m_{1:t}|s=+1)}{p(m_{1:t}|s=-1)}\n",
    "\\end{align*}\n",
    "\n",
    "Since our data is independent and identically distribution, the probability of all measurements given the state equals the product of the separate probabilities of each measurement given the state ($p(m_{1:t}|s) = \\prod_{t=1}^T p(m_t | s) $). We can substitute this in and use log properties to convert to a sum.\n",
    "\n",
    "\\begin{align*}\n",
    "L_T &= log\\frac{p(m_{1:t}|s=+1)}{p(m_{1:t}|s=-1)}\\\\\n",
    "&= log\\frac{\\prod_{t=1}^Tp(m_{t}|s=+1)}{\\prod_{t=1}^Tp(m_{t}|s=-1)}\\\\\n",
    "&= \\sum_{t=1}^T log\\frac{p(m_{t}|s=+1)}{p(m_{t}|s=-1)}\\\\\n",
    "&= \\sum_{t=1}^T \\Delta_t\n",
    "\\end{align*}\n",
    "\n",
    "In the last line, we have used $\\Delta_t = log\\frac{p(m_{t}|s=+1)}{p(m_{t}|s=-1)}$. \n",
    "\n",
    "To get the full log likelihood ratio, we are summing up the log likelihood ratios at each time step. The log likelihood ratio at a time step ($L_T$) will equal the ratio at the previous time step ($L_{T-1}$) plus the ratio for the measurement at that time step, given by $\\Delta_T$:\n",
    "\n",
    "\\begin{align*}\n",
    "L_T =  L_{T-1} + \\Delta_T\n",
    "\\end{align*}\n",
    "\n",
    "The SPRT states that if $L_T$ is positive, then the state $s=+1$ is more likely than $s=-1$! \n",
    "\n",
    "\n",
    "**Sequential Probability Ratio Test as a Drift Diffusion Model**\n",
    "\n",
    "Let's assume that the probability of seeing a measurement given the state is a Gaussian (Normal) distribution where the mean ($\\mu$) is different for the two states but the standard deviation ($\\sigma$) is the same:\n",
    "\n",
    "\\begin{align*}\n",
    "p(m_t | s = +1) &= \\mathcal{N}(\\mu, \\sigma^2)\\\\\n",
    "p(m_t | s = -1) &= \\mathcal{N}(-\\mu, \\sigma^2)\\\\\n",
    "\\end{align*}\n",
    "\n",
    "We can write the new evidence (the log likelihood ratio for the measurement at time $t$) as\n",
    "\n",
    "$$\\Delta_t=b+c\\epsilon_t$$\n",
    "\n",
    "The first term, $b$, is a consistant value and equals $b=2\\mu^2/\\sigma^2$. This term favors the actual hidden state. The second term, $c\\epsilon_t$ where $\\epsilon_t\\sim\\mathcal{N}(0,1)$, is a standard random variable which is scaled by the diffusion $c=2\\mu/\\sigma$. You can work through proving this in the bonus exercise 0 below if you wish!\n",
    "\n",
    "The accumulation of evidence will thus \"drift\" toward one outcome, while \"diffusing\" in random directions, hence the term \"drift-diffusion model\" (DDM). The process is most likely (but not guaranteed) to reach the correct outcome eventually.\n"
   ]
  },
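  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "To get a feel for the sizes of the drift and diffusion terms above, here is a small worked example. The values $\\mu=0.2$ and $\\sigma=3.5$ are assumed purely for illustration.\n",
    "\n",
    "\\begin{align*}\n",
    "b &= \\frac{2\\mu^2}{\\sigma^2} = \\frac{2(0.2)^2}{(3.5)^2} = \\frac{0.08}{12.25} \\approx 0.0065\\\\\n",
    "c &= \\frac{2\\mu}{\\sigma} = \\frac{0.4}{3.5} \\approx 0.114\\\\\n",
    "\\frac{c}{b} &= \\frac{\\sigma}{\\mu} = 17.5\n",
    "\\end{align*}\n",
    "\n",
    "Under these assumed values, the standard deviation of the noise added at each step is about 17.5 times larger than the drift, so a single measurement says very little and evidence has to be accumulated over many steps."
   ]
  },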
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "<details>\n",
    "<summary><font color='blue'>Bonus math exercise 0: derive Drift Diffusion Model from SPRT</font>\n",
    "</summary>\n",
    "\n",
    "We can do a little math to find the SPRT update $\\Delta_t$ to the log-likelihood ratio. You can derive this yourself, filling in the steps below, or skip to the end result.\n",
    "\n",
    "Assume measurements are Gaussian-distributed with different means depending on the discrete latent variable $s$:\n",
    "$$p(m|s=\\pm 1) = \\mathcal{N}\\left(\\mu_\\pm,\\sigma^2\\right)=\\frac{1}{\\sqrt{2\\pi\\sigma^2}}\\exp{\\left[-\\frac{(m-\\mu_\\pm)^2}{2\\sigma^2}\\right]}$$\n",
    "\n",
    "In the log likelihood ratio for a single data point $m_i$, the normalizations cancel to give\n",
    "$$\\Delta_t=\\log \\frac{p(m_t|s=+1)}{p(m_t|s=-1)} = \\frac{1}{2\\sigma^2}\\left[-\\left(m_t-\\mu_+\\right)^2 + (m_t-\\mu_-)^2\\right] \\tag{5}$$\n",
    "\n",
    "It's convenient to rewrite $m=\\mu_\\pm + \\sigma \\epsilon$, where $\\epsilon\\sim \\mathcal{N}(0,1)$ is a standard Gaussian variable with zero mean and unit variance. (Why does this give the correct probability for $m$?). The preceding formula can then be rewritten as \n",
    "$$\\Delta_t = \\frac{1}{2\\sigma^2}\\left( -((\\mu_\\pm+\\sigma\\epsilon)-\\mu_+)^2 + ((\\mu_\\pm+\\sigma\\epsilon)-\\mu_-)^2\\right) \\tag{5}$$\n",
    "Let's assume that $s=+1$ so $\\mu_\\pm=\\mu_+$ (if $s=-1$ then the result is the same with a reversed sign). In that case, the means in the first term $m_t-\\mu_+$ cancel, leaving\n",
    "$$\\Delta_t = \\frac{\\delta^2\\mu^2}{2\\sigma^2}+\\frac{\\delta\\mu}{\\sigma}\\epsilon_t \\tag{5}$$\n",
    "where $\\delta\\mu=\\mu_+-\\mu_-$. If we take $\\mu_\\pm=\\pm\\mu$, then $\\delta\\mu=2\\mu$, and\n",
    "$$\\Delta_t=2\\frac{\\mu^2}{\\sigma^2}+2\\frac{\\mu}{\\sigma}\\epsilon_t$$\n",
    "\n",
    "The first term is a constant *drift*, and the second term is a random *diffusion*.\n",
    "\n",
    "The SPRT says that we should add up these evidences, $L_T=\\sum_{t=1}^T \\Delta_t$. Note that the $\\Delta_t$ are independent. Recall that for independent random variables, the mean of a sum is the sum of the means. And the variance of a sum is the sum of the variances. \n",
    "\n",
    "</details>\n",
    "\n",
    "Adding these $\\Delta_t$ over time gives\n",
    "$$L_T\\sim\\mathcal{N}\\left(2\\frac{\\mu^2}{\\sigma^2}T,\\ 4\\frac{\\mu^2}{\\sigma^2}T\\right)=\\mathcal{N}(bT,c^2T)$$\n",
    "as claimed. The log-likelihood ratio $L_t$ is a biased random walk --- normally distributed with a time-dependent mean and variance. This is the Drift Diffusion Model.\n",
    "\n"
   ]
  },
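  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "The algebra that the bonus derivation above skips over can be sketched as follows. This assumes the true state is $s=+1$, so that $m_t=\\mu_+ + \\sigma\\epsilon_t$, and writes $\\delta\\mu=\\mu_+-\\mu_-$ as a single symbol. Then $m_t-\\mu_+=\\sigma\\epsilon_t$ and $m_t-\\mu_-=\\delta\\mu+\\sigma\\epsilon_t$, so\n",
    "\n",
    "\\begin{align*}\n",
    "\\Delta_t &= \\frac{1}{2\\sigma^2}\\left[-(\\sigma\\epsilon_t)^2 + (\\delta\\mu+\\sigma\\epsilon_t)^2\\right]\\\\\n",
    "&= \\frac{1}{2\\sigma^2}\\left[(\\delta\\mu)^2 + 2\\,\\delta\\mu\\,\\sigma\\epsilon_t\\right]\\\\\n",
    "&= \\frac{(\\delta\\mu)^2}{2\\sigma^2} + \\frac{\\delta\\mu}{\\sigma}\\epsilon_t\n",
    "\\end{align*}\n",
    "\n",
    "The $\\sigma^2\\epsilon_t^2$ terms cancel, which is why the update is linear in the noise. With $\\mu_\\pm=\\pm\\mu$ this is $\\Delta_t=b+c\\epsilon_t$ with $b=2\\mu^2/\\sigma^2$ and $c=2\\mu/\\sigma$. Because the $\\Delta_t$ are independent, means and variances simply add:\n",
    "\n",
    "\\begin{align*}\n",
    "\\mathbb{E}[L_T] &= \\sum_{t=1}^T \\mathbb{E}[\\Delta_t] = bT\\\\\n",
    "\\mathrm{Var}[L_T] &= \\sum_{t=1}^T \\mathrm{Var}[\\Delta_t] = c^2T\n",
    "\\end{align*}"
   ]
  },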
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "##  Coding Exercise 1.1: Simulating an SPRT model\n",
    "\n",
    "Let's now generate simulated data with $s=+1$ and see if the SPRT can infer the state correctly.\n",
    "\n",
    "We will implement a function `simulate_SPRT_fixedtime`, which will generate measurements based on $\\mu$, $\\sigma$, and the true state. It will then accumulate evidence over the time steps and output a decision on the state. The decision will be the state that is more likely according to the accumulated evidence. We will use the helper function `log_likelihood_ratio`, implemented in the next cell, which computes the log of the likelihood of the state being 1 divided by the likelihood of the state being -1. \n",
    "\n",
    "**Your coding tasks are:**\n",
    "\n",
    "**Step 1**: accumulate evidence.\n",
    "\n",
    "**Step 2**: make a decision at the last time point.\n",
    "\n",
    "We will then visualize 10 simulations of the DDM. In the next exercise you'll see how the parameters affect performance.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# @markdown Execute this cell to enable the helper function `log_likelihood_ratio`\n",
    "\n",
    "def log_likelihood_ratio(Mvec, p0, p1):\n",
    "  \"\"\"Given a sequence(vector) of observed data, calculate the log of\n",
    "  likelihood ratio of p1 and p0\n",
    "\n",
    "  Args:\n",
    "    Mvec (numpy vector):           A vector of scalar measurements\n",
    "    p0 (Gaussian random variable): A normal random variable with `logpdf'\n",
    "                                    method\n",
    "    p1 (Gaussian random variable): A normal random variable with `logpdf`\n",
    "                                    method\n",
    "\n",
    "  Returns:\n",
    "    llvec: a vector of log likelihood ratios for each input data point\n",
    "  \"\"\"\n",
    "  return p1.logpdf(Mvec) - p0.logpdf(Mvec)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "def simulate_SPRT_fixedtime(mu, sigma, stop_time, true_dist = 1):\n",
    "  \"\"\"Simulate a Sequential Probability Ratio Test with fixed time stopping\n",
    "  rule. Two observation models are 1D Gaussian distributions N(1,sigma^2) and\n",
    "  N(-1,sigma^2).\n",
    "\n",
    "  Args:\n",
    "    mu (float): absolute mean value of the symmetric observation distributions\n",
    "    sigma (float): Standard deviation of observation models\n",
    "    stop_time (int): Number of samples to take before stopping\n",
    "    true_dist (1 or -1): Which state is the true state.\n",
    "\n",
    "  Returns:\n",
    "    evidence_history (numpy vector): the history of cumulated evidence given\n",
    "                                      generated data\n",
    "    decision (int): 1 for s = 1, -1 for s = -1\n",
    "    Mvec (numpy vector): the generated sequences of measurement data in this trial\n",
    "  \"\"\"\n",
    "\n",
    "  #################################################\n",
    "  ## TODO for students ##\n",
    "  # Fill out function and remove\n",
    "  raise NotImplementedError(\"Student exercise: complete simulate_SPRT_fixedtime\")\n",
    "  #################################################\n",
    "\n",
    "  # Set means of observation distributions\n",
    "  assert mu > 0, \"Mu should be > 0\"\n",
    "  mu_pos = mu\n",
    "  mu_neg = -mu\n",
    "\n",
    "  # Make observation distributions\n",
    "  p_pos = stats.norm(loc = mu_pos, scale = sigma)\n",
    "  p_neg = stats.norm(loc = mu_neg, scale = sigma)\n",
    "\n",
    "  # Generate a random sequence of measurements\n",
    "  if true_dist == 1:\n",
    "    Mvec = p_pos.rvs(size = stop_time)\n",
    "  else:\n",
    "    Mvec = p_neg.rvs(size = stop_time)\n",
    "\n",
    "  # Calculate log likelihood ratio for each measurement (delta_t)\n",
    "  ll_ratio_vec = log_likelihood_ratio(Mvec, p_neg, p_pos)\n",
    "\n",
    "  # STEP 1: Calculate accumulated evidence (S) given a time series of evidence (hint: np.cumsum)\n",
    "  evidence_history = ...\n",
    "\n",
    "  # STEP 2: Make decision based on the sign of the evidence at the final time.\n",
    "  decision = ...\n",
    "\n",
    "  return evidence_history, decision, Mvec\n",
    "\n",
    "\n",
    "# Set random seed\n",
    "np.random.seed(100)\n",
    "\n",
    "# Set model parameters\n",
    "mu = .2\n",
    "sigma = 3.5  # standard deviation for p+ and p-\n",
    "num_sample = 10  # number of simulations to run\n",
    "stop_time = 150 # number of steps before stopping\n",
    "\n",
    "# Simulate and visualize\n",
    "simulate_and_plot_SPRT_fixedtime(mu, sigma, stop_time, num_sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# to_remove solution\n",
    "def simulate_SPRT_fixedtime(mu, sigma, stop_time, true_dist = 1):\n",
    "  \"\"\"Simulate a Sequential Probability Ratio Test with fixed time stopping\n",
    "  rule. Two observation models are 1D Gaussian distributions N(1,sigma^2) and\n",
    "  N(-1,sigma^2).\n",
    "\n",
    "  Args:\n",
    "    mu (float): absolute mean value of the symmetric observation distributions\n",
    "    sigma (float): Standard deviation of observation models\n",
    "    stop_time (int): Number of samples to take before stopping\n",
    "    true_dist (1 or -1): Which state is the true state.\n",
    "\n",
    "  Returns:\n",
    "    evidence_history (numpy vector): the history of cumulated evidence given\n",
    "                                      generated data\n",
    "    decision (int): 1 for s = 1, -1 for s = -1\n",
    "    Mvec (numpy vector): the generated sequences of measurement data in this trial\n",
    "  \"\"\"\n",
    "\n",
    "  # Set means of observation distributions\n",
    "  assert mu > 0, \"Mu should be > 0\"\n",
    "  mu_pos = mu\n",
    "  mu_neg = -mu\n",
    "\n",
    "  # Make observation distributions\n",
    "  p_pos = stats.norm(loc = mu_pos, scale = sigma)\n",
    "  p_neg = stats.norm(loc = mu_neg, scale = sigma)\n",
    "\n",
    "  # Generate a random sequence of measurements\n",
    "  if true_dist == 1:\n",
    "    Mvec = p_pos.rvs(size = stop_time)\n",
    "  else:\n",
    "    Mvec = p_neg.rvs(size = stop_time)\n",
    "\n",
    "  # Calculate log likelihood ratio for each measurement (delta_t)\n",
    "  ll_ratio_vec = log_likelihood_ratio(Mvec, p_neg, p_pos)\n",
    "\n",
    "  # STEP 1: Calculate accumulated evidence (S) given a time series of evidence (hint: np.cumsum)\n",
    "  evidence_history = np.cumsum(ll_ratio_vec)\n",
    "\n",
    "  # STEP 2: Make decision based on the sign of the evidence at the final time.\n",
    "  decision = np.sign(evidence_history[-1])\n",
    "\n",
    "  return evidence_history, decision, Mvec\n",
    "\n",
    "\n",
    "# Set random seed\n",
    "np.random.seed(100)\n",
    "\n",
    "# Set model parameters\n",
    "mu = .2\n",
    "sigma = 3.5  # standard deviation for p+ and p-\n",
    "num_sample = 10  # number of simulations to run\n",
    "stop_time = 150 # number of steps before stopping\n",
    "\n",
    "# Simulate and visualize\n",
    "simulate_and_plot_SPRT_fixedtime(mu, sigma, stop_time, num_sample)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "## Interactive Demo 1.2: Trajectories under the fixed-time stopping rule\n",
    "\n",
    "\n",
    "In the following demo, you can change the drift level (mu), noise level (sigma) in the observation model and the number of time steps before stopping (stop_time) using the sliders. You will then observe 10 simulations with those parameters. As in the previous exercise, the true state is +1.\n",
    " \n",
    "\n",
    "\n",
    "1.   Are you more likely to make the wrong decision (choose the incorrect state) with high or low noise?\n",
    "2. What happens when sigma is very small? Why?\n",
    "3.   Are you more likely to make the wrong decision (choose the incorrect state) with fewer or more time steps before stopping?\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# @markdown Make sure you execute this cell to enable the widget!\n",
    "\n",
    "def simulate_SPRT_fixedtime(mu, sigma, stop_time, true_dist = 1):\n",
    "  \"\"\"Simulate a Sequential Probability Ratio Test with fixed time stopping\n",
    "  rule. Two observation models are 1D Gaussian distributions N(1,sigma^2) and\n",
    "  N(-1,sigma^2).\n",
    "\n",
    "  Args:\n",
    "    mu (float): absolute mean value of the symmetric observation distributions\n",
    "    sigma (float): Standard deviation of observation models\n",
    "    stop_time (int): Number of samples to take before stopping\n",
    "    true_dist (1 or -1): Which state is the true state.\n",
    "\n",
    "  Returns:\n",
    "    evidence_history (numpy vector): the history of cumulated evidence given\n",
    "                                      generated data\n",
    "    decision (int): 1 for s = 1, -1 for s = -1\n",
    "    Mvec (numpy vector): the generated sequences of measurement data in this trial\n",
    "  \"\"\"\n",
    "\n",
    "  # Set means of observation distributions\n",
    "  assert mu > 0, \"Mu should be >0\"\n",
    "  mu_pos = mu\n",
    "  mu_neg = -mu\n",
    "\n",
    "  # Make observation distributions\n",
    "  p_pos = stats.norm(loc = mu_pos, scale = sigma)\n",
    "  p_neg = stats.norm(loc = mu_neg, scale = sigma)\n",
    "\n",
    "  # Generate a random sequence of measurements\n",
    "  if true_dist == 1:\n",
    "    Mvec = p_pos.rvs(size = stop_time)\n",
    "  else:\n",
    "    Mvec = p_neg.rvs(size = stop_time)\n",
    "\n",
    "  # Calculate log likelihood ratio for each measurement (delta_t)\n",
    "  ll_ratio_vec = log_likelihood_ratio(Mvec, p_neg, p_pos)\n",
    "\n",
    "  # STEP 1: Calculate accumulated evidence (S) given a time series of evidence (hint: np.cumsum)\n",
    "  evidence_history = np.cumsum(ll_ratio_vec)\n",
    "\n",
    "  # STEP 2: Make decision based on the sign of the evidence at the final time.\n",
    "  decision = np.sign(evidence_history[-1])\n",
    "\n",
    "  return evidence_history, decision, Mvec\n",
    "\n",
    "np.random.seed(100)\n",
    "num_sample = 10\n",
    "\n",
    "@widgets.interact(mu=widgets.FloatSlider(min=0.1, max=5.0, step=0.1, value=0.5), sigma=(0.05, 10.0, 0.05), stop_time=(5, 500, 1))\n",
    "def plot(mu, sigma, stop_time):\n",
    "  simulate_and_plot_SPRT_fixedtime(mu, sigma, stop_time, num_sample, verbose=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# to_remove explanation\n",
    "\n",
    "\"\"\"\n",
    "\n",
    "1) Higher noise, or higher sigma, means that the evidence accumulation varies up\n",
    "   and down more. You are more likely to make a wrong decision with high noise,\n",
    "   since the accumulated log likelihood ratio is more likely to be negative at the end\n",
    "   despite the true distribution being s = +1.\n",
    "\n",
    "2) When sigma is very small, the cumulated log likelihood ratios are basically a linear\n",
    "   diagonal line. This is because each new measurement will be very similar (since they are\n",
    "   being drawn from a Gaussian with a tiny standard deviation)\n",
    "\n",
    "3) You are more likely to be wrong with a small number of time steps before decision. There is\n",
    "   more change that the noise will affect the decision. We will explore this in the next section.\n",
    "\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "---\n",
    "# Section 2: Analyzing the DDM: accuracy vs stopping time"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "If you make a hasty decision (e.g., after only seeing 2 samples), or if observation noise buries the signal, you may see a negative accumulated log likelihood ratio and thus make a wrong decision. Let's plot how decision accuracy varies with the number of samples. Accuracy is the proportion of correct trials across our repeated simulations: $\\frac{\\# \\textrm{ correct decisions}}{\\# \\textrm{ total decisions}}$.\n",
    "\n",
    "\n",
    "\n"
   ]
  },
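  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "For the fixed-time stopping rule, accuracy can also be written in closed form. The sketch below assumes the true state is $s=+1$, the decision is the sign of $L_T$, and $L_T\\sim\\mathcal{N}(bT, c^2T)$ with $b=2\\mu^2/\\sigma^2$ and $c=2\\mu/\\sigma$ as in Section 1. Here $\\Phi$ denotes the standard normal cumulative distribution function.\n",
    "\n",
    "\\begin{align*}\n",
    "P(\\textrm{correct}) &= P(L_T > 0) = \\Phi\\left(\\frac{bT}{c\\sqrt{T}}\\right) = \\Phi\\left(\\frac{\\mu\\sqrt{T}}{\\sigma}\\right)\\\\\n",
    "&= \\frac{1}{2}\\left[1 + \\textrm{erf}\\left(\\frac{\\mu\\sqrt{T}}{\\sqrt{2}\\,\\sigma}\\right)\\right]\n",
    "\\end{align*}\n",
    "\n",
    "As an illustration with assumed values $\\mu=0.2$ and $\\sigma=3.5$: stopping after $T=2$ samples gives $\\Phi(0.08)\\approx 0.53$, barely above chance, while stopping after $T=150$ samples gives $\\Phi(0.70)\\approx 0.76$. Accuracy depends on the parameters only through $\\mu\\sqrt{T}/\\sigma$, so it rises with more samples and falls with more noise."
   ]
  },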
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "### Coding Exercise 2.1: The Speed/Accuracy Tradeoff\n",
    "\n",
    "We will fix our observation noise level. In this exercise you will implement a function to run many simulations for a certain stopping time, and calculate the _average decision accuracy_. We will then visualize the relation between average decision accuracy and stopping time. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "def simulate_accuracy_vs_stoptime(mu, sigma, stop_time_list, num_sample, no_numerical=False):\n",
    "  \"\"\"Calculate the average decision accuracy vs. stopping time by running\n",
    "  repeated SPRT simulations for each stop time.\n",
    "\n",
    "  Args:\n",
    "      mu (float): absolute mean value of the symmetric observation distributions\n",
    "      sigma (float): standard deviation for observation model\n",
    "      stop_list_list (list-like object): a list of stopping times to run over\n",
    "      num_sample (int): number of simulations to run per stopping time\n",
    "      no_numerical (bool): flag that indicates the function to return analytical values only\n",
    "\n",
    "  Returns:\n",
    "      accuracy_list: a list of average accuracies corresponding to input\n",
    "                      `stop_time_list`\n",
    "      decisions_list: a list of decisions made in all trials\n",
    "  \"\"\"\n",
    "\n",
    "  #################################################\n",
    "  ## TODO for students##\n",
    "  # Fill out function and remove\n",
    "  raise NotImplementedError(\"Student exercise: complete simulate_accuracy_vs_stoptime\")\n",
    "  #################################################\n",
    "\n",
    "  # Determine true state (1 or -1)\n",
    "  true_dist = 1\n",
    "\n",
    "  # Set up tracker of accuracy and decisions\n",
    "  accuracies = np.zeros(len(stop_time_list),)\n",
    "  accuracies_analytical = np.zeros(len(stop_time_list),)\n",
    "  decisions_list = []\n",
    "\n",
    "  # Loop over stop times\n",
    "  for i_stop_time, stop_time in enumerate(stop_time_list):\n",
    "\n",
    "    if not no_numerical:\n",
    "      # Set up tracker of decisions for this stop time\n",
    "      decisions = np.zeros((num_sample,))\n",
    "\n",
    "      # Loop over samples\n",
    "      for i in range(num_sample):\n",
    "\n",
    "        # STEP 1: Simulate run for this stop time (hint: use output from last exercise)\n",
    "        _, decision, _= ...\n",
    "\n",
    "        # Log decision\n",
    "        decisions[i] = decision\n",
    "\n",
    "      # STEP 2: Calculate accuracy by averaging over trials\n",
    "      accuracies[i_stop_time] = ...\n",
    "\n",
    "      # Log decision\n",
    "      decisions_list.append(decisions)\n",
    "\n",
    "    # Calculate analytical accuracy\n",
    "    sigma_sum_gaussian = sigma / np.sqrt(stop_time)\n",
    "    accuracies_analytical[i_stop_time] = 0.5 + 0.5 * erf(mu / np.sqrt(2) / sigma_sum_gaussian)\n",
    "\n",
    "  return accuracies, accuracies_analytical, decisions_list\n",
    "\n",
    "# Set random seed\n",
    "np.random.seed(100)\n",
    "\n",
    "# Set parameters of model\n",
    "mu = 0.5\n",
    "sigma = 4.65  # standard deviation for observation noise\n",
    "num_sample = 100  # number of simulations to run for each stopping time\n",
    "stop_time_list = np.arange(1, 150, 10) # Array of stopping times to use\n",
    "\n",
    "\n",
    "# Calculate accuracies for each stop time\n",
    "accuracies, accuracies_analytical, _ = simulate_accuracy_vs_stoptime(mu, sigma, stop_time_list,\n",
    "                                                   num_sample)\n",
    "\n",
    "# Visualize\n",
    "plot_accuracy_vs_stoptime(mu, sigma, stop_time_list, accuracies_analytical, accuracies)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# to_remove solution\n",
    "\n",
    "def simulate_accuracy_vs_stoptime(mu, sigma, stop_time_list, num_sample, no_numerical=False):\n",
    "  \"\"\"Calculate the average decision accuracy vs. stopping time by running\n",
    "  repeated SPRT simulations for each stop time.\n",
    "\n",
    "  Args:\n",
    "      mu (float): absolute mean value of the symmetric observation distributions\n",
    "      sigma (float): standard deviation for observation model\n",
    "      stop_list_list (list-like object): a list of stopping times to run over\n",
    "      num_sample (int): number of simulations to run per stopping time\n",
    "      no_numerical (bool): flag that indicates the function to return analytical values only\n",
    "\n",
    "  Returns:\n",
    "      accuracy_list: a list of average accuracies corresponding to input\n",
    "                      `stop_time_list`\n",
    "      decisions_list: a list of decisions made in all trials\n",
    "  \"\"\"\n",
    "\n",
    "  # Determine true state (1 or -1)\n",
    "  true_dist = 1\n",
    "\n",
    "  # Set up tracker of accuracy and decisions\n",
    "  accuracies = np.zeros(len(stop_time_list),)\n",
    "  accuracies_analytical = np.zeros(len(stop_time_list),)\n",
    "  decisions_list = []\n",
    "\n",
    "  # Loop over stop times\n",
    "  for i_stop_time, stop_time in enumerate(stop_time_list):\n",
    "\n",
    "    if not no_numerical:\n",
    "      # Set up tracker of decisions for this stop time\n",
    "      decisions = np.zeros((num_sample,))\n",
    "\n",
    "      # Loop over samples\n",
    "      for i in range(num_sample):\n",
    "\n",
    "        # STEP 1: Simulate run for this stop time (hint: use output from last exercise)\n",
    "        _, decision, _= simulate_SPRT_fixedtime(mu, sigma, stop_time, true_dist)\n",
    "\n",
    "        # Log decision\n",
    "        decisions[i] = decision\n",
    "\n",
    "      # STEP 2: Calculate accuracy by averaging over trials\n",
    "      accuracies[i_stop_time] = np.sum(decisions == true_dist) / decisions.shape[0]\n",
    "\n",
    "      # Store the decisions\n",
    "      decisions_list.append(decisions)\n",
    "\n",
    "    # Calculate analytical accuracy\n",
    "    # S_t is a normal variable with SNR scale as sqrt(stop_time)\n",
    "    sigma_sum_gaussian = sigma / np.sqrt(stop_time)\n",
    "    accuracies_analytical[i_stop_time] = 0.5 + 0.5 * erf(mu / np.sqrt(2) / sigma_sum_gaussian)\n",
    "\n",
    "  return accuracies, accuracies_analytical, decisions_list\n",
    "\n",
    "\n",
    "# Set random seed\n",
    "np.random.seed(100)\n",
    "\n",
    "# Set parameters of model\n",
    "mu = 0.5\n",
    "sigma = 4.65  # standard deviation for observation noise\n",
    "num_sample = 100  # number of simulations to run for each stopping time\n",
    "stop_time_list = np.arange(1, 150, 10) # Array of stopping times to use\n",
    "\n",
    "# Calculate accuracies for each stop time\n",
    "accuracies, accuracies_analytical, _ = simulate_accuracy_vs_stoptime(mu, sigma, stop_time_list,\n",
    "                                                   num_sample)\n",
    "\n",
    "# Visualize\n",
    "plot_accuracy_vs_stoptime(mu, sigma, stop_time_list, accuracies_analytical, accuracies)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "In the figure above, we are plotting the simulated accuracies in orange. We can actually find an analytical equation for the average accuracy in this specific case, which we plot in blue. We will not dive into this analytical solution here but you can imagine that if you ran a bunch of different simulations and had the equivalent number of orange lines, the average of those would resemble the blue line. \n",
    "\n",
    "In the insets, we are showing the evidence distributions for the two states at a certain time point. Recall from Section 1 that the likelihood ratio at time $T$ for state of +1 is: $$L_T\\sim\\mathcal{N}\\left(2\\frac{\\mu^2}{\\sigma^2}T,\\ 4\\frac{\\mu^2}{\\sigma^2}T\\right)=\\mathcal{N}(bT,c^2T)$$\n",
    "\n",
    "If the state is -1, the mean is the reverse sign. We are plotting this Gaussian distribution for the state equaling -1 (dashed line) and the state equaling +1 (solid line). The area in red reflects the error rate - this region corresponds to $L_T$ being below 0 even though the true state is +1 so you would decide on the wrong state. As more time goes by, these distributions separate more and the error is lower."
   ]
  },
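  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "For the curious, here is a brief sketch of where the blue analytical curve comes from. It assumes the true state is +1 and uses only the Gaussian form of $L_T$ given above; the symbol $\\Phi$ (the standard normal cumulative distribution function) is notation introduced here for illustration. A fixed-time decision is correct whenever $L_T > 0$, so\n",
    "\n",
    "$$\n",
    "\\begin{align}\n",
    "P(\\textrm{correct}) &= P(L_T > 0) = \\Phi\\left(\\frac{bT}{c\\sqrt{T}}\\right) = \\Phi\\left(\\frac{\\mu\\sqrt{T}}{\\sigma}\\right) \\\\\n",
    "&= \\frac{1}{2} + \\frac{1}{2}\\,\\textrm{erf}\\left(\\frac{\\mu\\sqrt{T}}{\\sqrt{2}\\,\\sigma}\\right)\n",
    "\\end{align}\n",
    "$$\n",
    "\n",
    "This agrees with the expression used for `accuracies_analytical` in the code above, where `sigma_sum_gaussian` equals $\\sigma/\\sqrt{T}$. Under these assumptions, accuracy depends on $\\mu$, $\\sigma$, and $T$ only through the ratio $\\mu\\sqrt{T}/\\sigma$."
   ]
  },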
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "### Interactive Demo 2.2: Accuracy versus stop-time\n",
    "\n",
    "For this same visualization, now vary the mean $\\mu$ and standard deviation `sigma` of the evidence. What do you predict will the accuracy vs stopping time plot look like for low noise and high noise?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "#@markdown Make sure you execute this cell to enable the widget!\n",
    "def simulate_accuracy_vs_stoptime(mu, sigma, stop_time_list, num_sample, no_numerical=False):\n",
    "  \"\"\"Calculate the average decision accuracy vs. stopping time by running\n",
    "  repeated SPRT simulations for each stop time.\n",
    "\n",
    "  Args:\n",
    "      mu (float): absolute mean value of the symmetric observation distributions\n",
    "      sigma (float): standard deviation for observation model\n",
    "      stop_list_list (list-like object): a list of stopping times to run over\n",
    "      num_sample (int): number of simulations to run per stopping time\n",
    "      no_numerical (bool): flag that indicates the function to return analytical values only\n",
    "\n",
    "  Returns:\n",
    "      accuracy_list: a list of average accuracies corresponding to input\n",
    "                      `stop_time_list`\n",
    "      decisions_list: a list of decisions made in all trials\n",
    "  \"\"\"\n",
    "\n",
    "  # Determine true state (1 or -1)\n",
    "  true_dist = 1\n",
    "\n",
    "  # Set up tracker of accuracy and decisions\n",
    "  accuracies = np.zeros(len(stop_time_list),)\n",
    "  accuracies_analytical = np.zeros(len(stop_time_list),)\n",
    "  decisions_list = []\n",
    "\n",
    "  # Loop over stop times\n",
    "  for i_stop_time, stop_time in enumerate(stop_time_list):\n",
    "\n",
    "    if not no_numerical:\n",
    "      # Set up tracker of decisions for this stop time\n",
    "      decisions = np.zeros((num_sample,))\n",
    "\n",
    "      # Loop over samples\n",
    "      for i in range(num_sample):\n",
    "\n",
    "        # Simulate run for this stop time (hint: last exercise)\n",
    "        _, decision, _= simulate_SPRT_fixedtime(mu, sigma, stop_time, true_dist)\n",
    "\n",
    "        # Log decision\n",
    "        decisions[i] = decision\n",
    "\n",
    "      # Calculate accuracy\n",
    "      accuracies[i_stop_time] = np.sum(decisions == true_dist) / decisions.shape[0]\n",
    "      # Log decisions\n",
    "      decisions_list.append(decisions)\n",
    "\n",
    "    # Calculate analytical accuracy\n",
    "    sigma_sum_gaussian = sigma / np.sqrt(stop_time)\n",
    "    accuracies_analytical[i_stop_time] = 0.5 + 0.5 * erf(mu / np.sqrt(2) / sigma_sum_gaussian)\n",
    "\n",
    "\n",
    "  return accuracies, accuracies_analytical, decisions_list\n",
    "\n",
    "np.random.seed(100)\n",
    "num_sample = 100\n",
    "stop_time_list = np.arange(1, 100, 1)\n",
    "\n",
    "@widgets.interact\n",
    "def plot(mu=widgets.FloatSlider(min=0.1, max=5.0, step=0.1, value=1.0), sigma=(0.05, 10.0, 0.05)):\n",
    " # Calculate accuracies for each stop time\n",
    "  _, accuracies_analytical, _ = simulate_accuracy_vs_stoptime(mu, sigma, stop_time_list, num_sample, no_numerical=True)\n",
    "\n",
    "  # Visualize\n",
    "  plot_accuracy_vs_stoptime(mu, sigma, stop_time_list, accuracies_analytical)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# to_remove explanation\n",
    "\n",
    "\"\"\"\n",
    "\n",
    "1) Low levels of noise results in higher accuracies generally, especially\n",
    "   at early stop times.\n",
    "\n",
    "2) High levels of noise results in lower accuracies generally.\n",
    "\n",
    "\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "**Application**\n",
    "\n",
    "We have looked at the drift diffusion model of decisions in the context of the fishing problem. There are lots of uses of this in neuroscience! As one example, a classic experimental task in neuroscience is the random dot kinematogram ([Newsome, Britten, Movshon 1989](https://www.nature.com/articles/341052a0.pdf)), in which a pattern of moving dots are moving in random directions but with some weak coherence that favors a net rightward or leftward motion. The observer must guess the direction. Neurons in the brain are informative about this task, and have responses that correlate with the choice, as predicted by the Drift Diffusion Model (Huk and Shadlen 2005).\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "After you finish the other tutorials, come back to see Bonus material to learn about a different stopping rule for DDMs: a fixed threshold on confidence."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "---\n",
    "# Summary\n",
    "\n",
    "Good job! By simulating Drift Diffusion Models, you have learnt how to:\n",
    "\n",
    "* Calculate individual sample evidence as the log likelihood ratio of two candidate models\n",
    "* Accumulate evidence from new data points, and compute posterior using recursive formula\n",
    "* Run repeated simulations to get an estimate of decision accuracies\n",
    "* Measure the speed-accuracy tradeoff"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "---\n",
    "# Bonus "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "---\n",
    "## Bonus Section 1: DDM with fixed thresholds on confidence"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "The next exercises consider a variant of the DDM with fixed confidence thresholds instead of fixed decision time. This may be a better description of neural integration. Please complete this material after you have finished the main content of all tutorials, if you would like extra information about this topic."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "### Bonus Coding Exercise 1.1, Coding: Simulating the DDM with fixed confidence thresholds\n",
    "\n",
    "*Referred to as exercise 3 in video*\n",
    "\n",
    "In this exercise, we will use thresholding as our stopping rule and observe the behavior of the DDM. \n",
    "\n",
    "With thresholding stopping rule, we define a desired error rate and will continue making measurements until that error rate is reached. Experimental evidence suggested that evidence accumulation and thresholding stopping strategy happens at neuronal level (see [this article](https://www.annualreviews.org/doi/full/10.1146/annurev.neuro.29.051605.113038) for further reading).\n",
    "\n",
    "* Complete the function `threshold_from_errorrate` to calculate the evidence threshold from desired error rate $\\alpha$ as described in the formulas below. The evidence thresholds $th_1$ and $th_0$ for $p_+$ and $p_-$ are opposite of each other as shown below, so you can just return the absolute value.\n",
    "$$\n",
    "\\begin{align}\n",
    " th_{L} &= \\log \\frac{\\alpha}{1-\\alpha} &= -th_{R} \\\\\n",
    " th_{R} &= \\log \\frac{1-\\alpha}{\\alpha} &= -th{_1}\\\\\n",
    " \\end{align}\n",
    " $$\n",
    "\n",
    "* Complete the function `simulate_SPRT_threshold` to simulate an SPRT with thresholding stopping rule given noise level and desired threshold \n",
    "\n",
    "* Run repeated simulations for a given noise level and a desired error rate visualize the DDM traces using our provided code \n"
   ]
  },
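  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "As a quick sanity check on the threshold formula, here is a short sketch that assumes equal prior probabilities for the two states (an assumption made here for illustration). Under that assumption, the posterior probability of the +1 state after accumulating a log likelihood ratio $L$ is $e^{L}/(1+e^{L})$. Evaluating it at the threshold $L = th_{R}$ gives\n",
    "\n",
    "$$\n",
    "\\begin{align}\n",
    "\\frac{e^{th_{R}}}{1+e^{th_{R}}} = \\frac{\\frac{1-\\alpha}{\\alpha}}{1+\\frac{1-\\alpha}{\\alpha}} = 1-\\alpha\n",
    "\\end{align}\n",
    "$$\n",
    "\n",
    "so stopping at $th_{R}$ corresponds to being at least $1-\\alpha$ confident in the +1 state. For example, $\\alpha = 10^{-3}$ gives $th_{R} = \\log 999 \\approx 6.91$."
   ]
  },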
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "def simulate_SPRT_threshold(mu, sigma, threshold , true_dist=1):\n",
    "  \"\"\"Simulate a Sequential Probability Ratio Test with thresholding stopping\n",
    "  rule. Two observation models are 1D Gaussian distributions N(1,sigma^2) and\n",
    "  N(-1,sigma^2).\n",
    "\n",
    "  Args:\n",
    "    mu (float): absolute mean value of the symmetric observation distributions\n",
    "    sigma (float): Standard deviation\n",
    "    threshold (float): Desired log likelihood ratio threshold to achieve\n",
    "                        before making decision\n",
    "\n",
    "  Returns:\n",
    "    evidence_history (numpy vector): the history of cumulated evidence given\n",
    "                                      generated data\n",
    "    decision (int): 1 for pR, 0 for pL\n",
    "    data (numpy vector): the generated sequences of data in this trial\n",
    "  \"\"\"\n",
    "  assert mu > 0, \"Mu should be > 0\"\n",
    "  muL = -mu\n",
    "  muR = mu\n",
    "\n",
    "  pL = stats.norm(muL, sigma)\n",
    "  pR = stats.norm(muR, sigma)\n",
    "\n",
    "  has_enough_data = False\n",
    "\n",
    "  data_history = []\n",
    "  evidence_history = []\n",
    "  current_evidence = 0.0\n",
    "\n",
    "  # Keep sampling data until threshold is crossed\n",
    "  while not has_enough_data:\n",
    "    if true_dist == 1:\n",
    "      Mvec = pR.rvs()\n",
    "    else:\n",
    "      Mvec = pL.rvs()\n",
    "\n",
    "    ########################################################################\n",
    "    # Insert your code here to:\n",
    "    #      * Calculate the log-likelihood ratio for the new sample\n",
    "    #      * Update the accumulated evidence\n",
    "    raise NotImplementedError(\"`simulate_SPRT_threshold` is incomplete\")\n",
    "    ########################################################################\n",
    "\n",
    "    # STEP 1: individual log likelihood ratios\n",
    "    ll_ratio = log_likelihood_ratio(...)\n",
    "\n",
    "    # STEP 2: accumulated evidence for this chunk\n",
    "    evidence_history.append(...)\n",
    "\n",
    "    # update the collection of all data\n",
    "    data_history.append(Mvec)\n",
    "    current_evidence = evidence_history[-1]\n",
    "\n",
    "    # check if we've got enough data\n",
    "    if abs(current_evidence) > threshold:\n",
    "      has_enough_data = True\n",
    "\n",
    "  data_history = np.array(data_history)\n",
    "  evidence_history = np.array(evidence_history)\n",
    "\n",
    "  # Make decision\n",
    "  if evidence_history[-1] >= 0:\n",
    "    decision = 1\n",
    "  elif evidence_history[-1] < 0:\n",
    "    decision = 0\n",
    "\n",
    "  return evidence_history, decision, data_history\n",
    "\n",
    "# Set parameters\n",
    "np.random.seed(100)\n",
    "mu = 1.0\n",
    "sigma = 2.8\n",
    "num_sample = 10\n",
    "log10_alpha = -3 # log10(alpha)\n",
    "alpha = np.power(10.0, log10_alpha)\n",
    "\n",
    "# Simulate and visualize\n",
    "simulate_and_plot_SPRT_fixedthreshold(mu, sigma, num_sample, alpha)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# to_remove solution\n",
    "def simulate_SPRT_threshold(mu, sigma, threshold , true_dist=1):\n",
    "  \"\"\"Simulate a Sequential Probability Ratio Test with thresholding stopping\n",
    "  rule. Two observation models are 1D Gaussian distributions N(1,sigma^2) and\n",
    "  N(-1,sigma^2).\n",
    "\n",
    "  Args:\n",
    "    mu (float): absolute mean value of the symmetric observation distributions\n",
    "    sigma (float): Standard deviation\n",
    "    threshold (float): Desired log likelihood ratio threshold to achieve\n",
    "                        before making decision\n",
    "\n",
    "  Returns:\n",
    "    evidence_history (numpy vector): the history of cumulated evidence given\n",
    "                                      generated data\n",
    "    decision (int): 1 for pR, 0 for pL\n",
    "    data (numpy vector): the generated sequences of data in this trial\n",
    "  \"\"\"\n",
    "  assert mu > 0, \"Mu should be > 0\"\n",
    "  muL = -mu\n",
    "  muR = mu\n",
    "\n",
    "  pL = stats.norm(muL, sigma)\n",
    "  pR = stats.norm(muR, sigma)\n",
    "\n",
    "  has_enough_data = False\n",
    "\n",
    "  data_history = []\n",
    "  evidence_history = []\n",
    "  current_evidence = 0.0\n",
    "\n",
    "  # Keep sampling data until threshold is crossed\n",
    "  while not has_enough_data:\n",
    "    if true_dist == 1:\n",
    "      Mvec = pR.rvs()\n",
    "    else:\n",
    "      Mvec = pL.rvs()\n",
    "\n",
    "    # STEP 1: individual log likelihood ratios\n",
    "    ll_ratio = log_likelihood_ratio(Mvec, pL, pR)\n",
    "\n",
    "    # STEP 2: accumulated evidence for this chunk\n",
    "    evidence_history.append(ll_ratio + current_evidence)\n",
    "\n",
    "    # update the collection of all data\n",
    "    data_history.append(Mvec)\n",
    "    current_evidence = evidence_history[-1]\n",
    "\n",
    "    # check if we've got enough data\n",
    "    if abs(current_evidence) > threshold:\n",
    "      has_enough_data = True\n",
    "\n",
    "  data_history = np.array(data_history)\n",
    "  evidence_history = np.array(evidence_history)\n",
    "\n",
    "  # Make decision\n",
    "  if evidence_history[-1] >= 0:\n",
    "    decision = 1\n",
    "  elif evidence_history[-1] < 0:\n",
    "    decision = 0\n",
    "\n",
    "  return evidence_history, decision, data_history\n",
    "\n",
    "\n",
    "# Set parameters\n",
    "np.random.seed(100)\n",
    "mu = 1.0\n",
    "sigma = 2.8\n",
    "num_sample = 10\n",
    "log10_alpha = -3 # log10(alpha)\n",
    "alpha = np.power(10.0, log10_alpha)\n",
    "\n",
    "# Simulate and visualize\n",
    "simulate_and_plot_SPRT_fixedthreshold(mu, sigma, num_sample, alpha)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "### Bonus Interactive Demo 1.2: DDM with fixed confidence threshold\n",
    "\n",
    "\n",
    "\n",
    "Play with difference values of `alpha` and `sigma` and observe how that affects the dynamics of Drift-Diffusion Model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# @markdown Make sure you execute this cell to enable the widget!\n",
    "def simulate_SPRT_threshold(mu, sigma, threshold , true_dist=1):\n",
    "  \"\"\"Simulate a Sequential Probability Ratio Test with thresholding stopping\n",
    "  rule. Two observation models are 1D Gaussian distributions N(1,sigma^2) and\n",
    "  N(-1,sigma^2).\n",
    "\n",
    "  Args:\n",
    "    mu (float): absolute mean value of the symmetric observation distributions\n",
    "    sigma (float): Standard deviation\n",
    "    threshold (float): Desired log likelihood ratio threshold to achieve\n",
    "                        before making decision\n",
    "\n",
    "  Returns:\n",
    "    evidence_history (numpy vector): the history of cumulated evidence given\n",
    "                                      generated data\n",
    "    decision (int): 1 for pR, 0 for pL\n",
    "    data (numpy vector): the generated sequences of data in this trial\n",
    "  \"\"\"\n",
    "  assert mu > 0, \"Mu should be > 0\"\n",
    "  muL = -mu\n",
    "  muR = mu\n",
    "\n",
    "  pL = stats.norm(muL, sigma)\n",
    "  pR = stats.norm(muR, sigma)\n",
    "\n",
    "  has_enough_data = False\n",
    "\n",
    "  data_history = []\n",
    "  evidence_history = []\n",
    "  current_evidence = 0.0\n",
    "\n",
    "  # Keep sampling data until threshold is crossed\n",
    "  while not has_enough_data:\n",
    "    if true_dist == 1:\n",
    "      Mvec = pR.rvs()\n",
    "    else:\n",
    "      Mvec = pL.rvs()\n",
    "\n",
    "    # STEP 1: individual log likelihood ratios\n",
    "    ll_ratio = log_likelihood_ratio(Mvec, pL, pR)\n",
    "\n",
    "    # STEP 2: accumulated evidence for this chunk\n",
    "    evidence_history.append(ll_ratio + current_evidence)\n",
    "\n",
    "    # update the collection of all data\n",
    "    data_history.append(Mvec)\n",
    "    current_evidence = evidence_history[-1]\n",
    "\n",
    "    # check if we've got enough data\n",
    "    if abs(current_evidence) > threshold:\n",
    "      has_enough_data = True\n",
    "\n",
    "  data_history = np.array(data_history)\n",
    "  evidence_history = np.array(evidence_history)\n",
    "\n",
    "  # Make decision\n",
    "  if evidence_history[-1] >= 0:\n",
    "    decision = 1\n",
    "  elif evidence_history[-1] < 0:\n",
    "    decision = 0\n",
    "\n",
    "  return evidence_history, decision, data_history\n",
    "\n",
    "np.random.seed(100)\n",
    "num_sample = 10\n",
    "\n",
    "@widgets.interact\n",
    "def plot(mu=(0.1,5.0,0.1), sigma=(0.05, 10.0, 0.05), log10_alpha=(-8, -1, .1)):\n",
    "  alpha = np.power(10.0, log10_alpha)\n",
    "  simulate_and_plot_SPRT_fixedthreshold(mu, sigma, num_sample, alpha, verbose=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "### Bonus Coding Exercise 1.3: Speed/Accuracy Tradeoff Revisited\n",
    "\n",
    "The faster you make a decision, the lower your accuracy often is. This phenomenon is known as the **speed/accuracy tradeoff**. Humans can make this tradeoff in a wide range of situations, and many animal species, including ants, bees, rodents, and monkeys also show similar effects. \n",
    "\n",
    "To illustrate the speed/accuracy tradeoff under thresholding stopping rule, let's run some simulations under different thresholds and look at how average decision \"speed\" (1/length) changes with average decision accuracy. We use speed rather than accuracy because in real experiments, subjects can be incentivized to respond faster or slower; it's much harder to precisely control their decision time or error threshold. \n",
    "\n",
    "* Complete the function `simulate_accuracy_vs_threshold` to simulate and compute average accuracies vs. average decision lengths for a list of error thresholds. You will need to supply code to calculate average decision 'speed' from the lengths of trials. You should also calculate the overall accuracy across these trials. \n",
    "\n",
    "* We've set up a list of error thresholds. Run repeated simulations and collect average accuracy with average length for each error rate in this list, and use our provided code to visualize the speed/accuracy tradeoff. You should see a positive correlation between length and accuracy.\n"
   ]
  },
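  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "To get a rough feel for why a stricter error rate costs time, here is a back-of-the-envelope sketch based on Wald's identity. It assumes the true state is +1 and ignores the overshoot of the evidence past the threshold at the moment of stopping, so it tends to underestimate the simulated decision times. Here $b = 2\\mu^2/\\sigma^2$ is the mean log likelihood ratio per sample, $th = \\log\\frac{1-\\alpha}{\\alpha}$ is the threshold, and $T$ denotes the (random) number of samples before the decision:\n",
    "\n",
    "$$\n",
    "\\begin{align}\n",
    "b\\,\\mathbb{E}[T] = \\mathbb{E}[L_T] &\\approx (1-\\alpha)\\,th + \\alpha\\,(-th) = (1-2\\alpha)\\,th \\\\\n",
    "\\mathbb{E}[T] &\\approx \\frac{(1-2\\alpha)\\,\\sigma^2}{2\\mu^2}\\log\\frac{1-\\alpha}{\\alpha}\n",
    "\\end{align}\n",
    "$$\n",
    "\n",
    "For example, with $\\mu = 1$, $\\sigma = 3.75$, and $\\alpha = 0.01$, this approximation gives about 32 samples. Lowering $\\alpha$ raises the threshold and therefore the expected decision time, which is the tradeoff explored in this exercise."
   ]
  },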
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "def simulate_accuracy_vs_threshold(mu, sigma, threshold_list, num_sample):\n",
    "  \"\"\"Calculate the average decision accuracy vs. average decision length by\n",
    "  running repeated SPRT simulations with thresholding stopping rule for each\n",
    "  threshold.\n",
    "\n",
    "  Args:\n",
    "      mu (float): absolute mean value of the symmetric observation distributions\n",
    "      sigma (float): standard deviation for observation model\n",
    "      threshold_list (list-like object): a list of evidence thresholds to run\n",
    "                                          over\n",
    "      num_sample (int): number of simulations to run per stopping time\n",
    "\n",
    "  Returns:\n",
    "      accuracy_list: a list of average accuracies corresponding to input\n",
    "                      `threshold_list`\n",
    "      decision_speed_list: a list of average decision speeds\n",
    "  \"\"\"\n",
    "  decision_speed_list = []\n",
    "  accuracy_list = []\n",
    "  for threshold in threshold_list:\n",
    "    decision_time_list = []\n",
    "    decision_list = []\n",
    "    for i in range(num_sample):\n",
    "      # run simulation and get decision of current simulation\n",
    "      _, decision, Mvec = simulate_SPRT_threshold(mu, sigma, threshold)\n",
    "      decision_time = len(Mvec)\n",
    "      decision_list.append(decision)\n",
    "      decision_time_list.append(decision_time)\n",
    "\n",
    "    ########################################################################\n",
    "    # Insert your code here to:\n",
    "    #      * Calculate mean decision speed given a list of decision times\n",
    "    #      * Hint: Think about speed as being inversely proportional\n",
    "    #        to decision_length. If it takes 10 seconds to make one decision,\n",
    "    #        our \"decision speed\" is 0.1 decisions per second.\n",
    "    #      * Calculate the decision accuracy\n",
    "    raise NotImplementedError(\"`simulate_accuracy_vs_threshold` is incomplete\")\n",
    "    ########################################################################\n",
    "    # Calculate and store average decision speed and accuracy\n",
    "    decision_speed = ...\n",
    "    decision_accuracy = ...\n",
    "    decision_speed_list.append(decision_speed)\n",
    "    accuracy_list.append(decision_accuracy)\n",
    "\n",
    "  return accuracy_list, decision_speed_list\n",
    "\n",
    "# Set parameters\n",
    "np.random.seed(100)\n",
    "mu = 1.0\n",
    "sigma = 3.75\n",
    "num_sample = 200\n",
    "alpha_list = np.logspace(-2, -0.1, 8)\n",
    "threshold_list = threshold_from_errorrate(alpha_list)\n",
    "\n",
    "# Simulate and visualize\n",
    "simulate_and_plot_accuracy_vs_threshold(mu, sigma, threshold_list, num_sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# to_remove solution\n",
    "def simulate_accuracy_vs_threshold(mu, sigma, threshold_list, num_sample):\n",
    "  \"\"\"Calculate the average decision accuracy vs. average decision speed by\n",
    "  running repeated SPRT simulations with thresholding stopping rule for each\n",
    "  threshold.\n",
    "\n",
    "  Args:\n",
    "      mu (float): absolute mean value of the symmetric observation distributions\n",
    "      sigma (float): standard deviation for observation model\n",
    "      threshold_list (list-like object): a list of evidence thresholds to run\n",
    "                                          over\n",
    "      num_sample (int): number of simulations to run per stopping time\n",
    "\n",
    "  Returns:\n",
    "      accuracy_list: a list of average accuracies corresponding to input\n",
    "                      `threshold_list`\n",
    "      decision_speed_list: a list of average decision speeds\n",
    "  \"\"\"\n",
    "  decision_speed_list = []\n",
    "  accuracy_list = []\n",
    "  for threshold in threshold_list:\n",
    "    decision_time_list = []\n",
    "    decision_list = []\n",
    "    for i in range(num_sample):\n",
    "      # run simulation and get decision of current simulation\n",
    "      _, decision, Mvec = simulate_SPRT_threshold(mu, sigma, threshold)\n",
    "      decision_time = len(Mvec)\n",
    "      decision_list.append(decision)\n",
    "      decision_time_list.append(decision_time)\n",
    "\n",
    "    # Calculate and store average decision speed and accuracy\n",
    "    decision_speed = np.mean(1. / np.array(decision_time_list))\n",
    "    decision_accuracy = sum(decision_list) / len(decision_list)\n",
    "    decision_speed_list.append(decision_speed)\n",
    "    accuracy_list.append(decision_accuracy)\n",
    "\n",
    "  return accuracy_list, decision_speed_list\n",
    "\n",
    "\n",
    "# Set parameters\n",
    "np.random.seed(100)\n",
    "mu = 1.0\n",
    "sigma = 3.75\n",
    "num_sample = 200\n",
    "alpha_list = np.logspace(-2, -0.1, 8)\n",
    "threshold_list = threshold_from_errorrate(alpha_list)\n",
    "\n",
    "# Simulate and visualize\n",
    "simulate_and_plot_accuracy_vs_threshold(mu, sigma, threshold_list, num_sample)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "### Bonus Interactive demo 1.4: Speed/Accuracy with a threshold rule\n",
    "\n",
    "Manipulate the noise level `sigma` and observe how that affects the speed/accuracy tradeoff."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# @markdown Make sure you execute this cell to enable the widget!\n",
    "def simulate_accuracy_vs_threshold(mu, sigma, threshold_list, num_sample):\n",
    "  \"\"\"Calculate the average decision accuracy vs. average decision speed by\n",
    "  running repeated SPRT simulations with thresholding stopping rule for each\n",
    "  threshold.\n",
    "\n",
    "  Args:\n",
    "      mu (float): absolute mean value of the symmetric observation distributions\n",
    "      sigma (float): standard deviation for observation model\n",
    "      threshold_list (list-like object): a list of evidence thresholds to run\n",
    "                                          over\n",
    "      num_sample (int): number of simulations to run per stopping time\n",
    "\n",
    "  Returns:\n",
    "      accuracy_list: a list of average accuracies corresponding to input\n",
    "                      `threshold_list`\n",
    "      decision_speed_list: a list of average decision speeds\n",
    "  \"\"\"\n",
    "  decision_speed_list = []\n",
    "  accuracy_list = []\n",
    "  for threshold in threshold_list:\n",
    "    decision_time_list = []\n",
    "    decision_list = []\n",
    "    for i in range(num_sample):\n",
    "      # run simulation and get decision of current simulation\n",
    "      _, decision, Mvec = simulate_SPRT_threshold(mu, sigma, threshold)\n",
    "      decision_time = len(Mvec)\n",
    "      decision_list.append(decision)\n",
    "      decision_time_list.append(decision_time)\n",
    "\n",
    "    # Calculate and store average decision speed and accuracy\n",
    "    decision_speed = np.mean(1. / np.array(decision_time_list))\n",
    "    decision_accuracy = sum(decision_list) / len(decision_list)\n",
    "    decision_speed_list.append(decision_speed)\n",
    "    accuracy_list.append(decision_accuracy)\n",
    "\n",
    "  return accuracy_list, decision_speed_list\n",
    "\n",
    "np.random.seed(100)\n",
    "num_sample = 100\n",
    "alpha_list = np.logspace(-2, -0.1, 8)\n",
    "threshold_list = threshold_from_errorrate(alpha_list)\n",
    "\n",
    "@widgets.interact\n",
    "def plot(mu=(0.1, 5.0, 0.1), sigma=(0.05, 10.0, 0.05)):\n",
    "  alpha = np.power(10.0, log10_alpha)\n",
    "  simulate_and_plot_accuracy_vs_threshold(mu, sigma, threshold_list, num_sample)"
   ]
  }
 ],
 "metadata": {
  "@webio": {
   "lastCommId": null,
   "lastKernelId": null
  },
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "W8_Tutorial1",
   "provenance": [],
   "toc_visible": true
  },
  "interpreter": {
   "hash": "9516f62da91337f10c2adbe814d9c63a4b08f8271333386358218606edb781e3"
  },
  "kernel": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "kernelspec": {
   "display_name": "Python 3.7.11 64-bit ('pw3': conda)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": true,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
